/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Matrix multiply
*/

#pragma once

#include <cuda_fp16.h>

#include "cutlass/arch/mma.h"

#include "cutlass/layout/matrix.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace arch {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation
template <typename LayoutA, typename LayoutB, typename LayoutC>
struct Mma<gemm::GemmShape<2, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB,
           half_t, LayoutC, OpMultiplyAdd> {
    using Shape = gemm::GemmShape<2, 1, 1>;
    using Operator = OpMultiplyAdd;

    CUTLASS_HOST_DEVICE
    void operator()(Array<half_t, 2>& d, Array<half_t, 2> const& a,
                    Array<half_t, 1> const& b, Array<half_t, 2> const& c) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

        __half2 const& A = reinterpret_cast<__half2 const&>(a);
        __half2 B = __half2half2(reinterpret_cast<__half const&>(b));
        __half2 const& C = reinterpret_cast<__half2 const&>(c);

        __half2 D = __hfma2(A, B, C);

        d = reinterpret_cast<Array<half_t, 2>&>(D);

#else
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < 2; ++i) {
            d[i] = a[i] * b[0] + c[i];
        }
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation
template <typename LayoutA, typename LayoutB>
struct Mma<gemm::GemmShape<1, 2, 1>, 1, half_t, LayoutA, half_t, LayoutB,
           half_t, layout::RowMajor, OpMultiplyAdd> {
    using Shape = gemm::GemmShape<1, 2, 1>;
    using Operator = OpMultiplyAdd;

    CUTLASS_HOST_DEVICE
    void operator()(Array<half_t, 2>& d, Array<half_t, 1> const& a,
                    Array<half_t, 2> const& b, Array<half_t, 2> const& c) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

        __half2 const& A = __half2half2(reinterpret_cast<__half const&>(a));
        __half2 B = reinterpret_cast<__half2 const&>(b);
        __half2 const& C = reinterpret_cast<__half2 const&>(c);

        __half2 D = __hfma2(A, B, C);

        d = reinterpret_cast<Array<half_t, 2>&>(D);

#else
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < 2; ++i) {
            d[i] = a[0] * b[i] + c[i];
        }
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation
template <>
struct Mma<gemm::GemmShape<2, 2, 1>, 1, half_t, layout::ColumnMajor, half_t,
           layout::RowMajor, half_t, layout::ColumnMajor, OpMultiplyAdd> {
    using Shape = gemm::GemmShape<2, 2, 1>;
    using Operator = OpMultiplyAdd;

    CUTLASS_HOST_DEVICE
    void operator()(Array<half_t, 4>& d, Array<half_t, 2> const& a,
                    Array<half_t, 2> const& b, Array<half_t, 4> const& c) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

        __half2 const& A = reinterpret_cast<__half2 const&>(a);
        __half2 Blo = __low2half2(reinterpret_cast<__half2 const&>(b));
        __half2 Bhi = __high2half2(reinterpret_cast<__half2 const&>(b));

        __half2 const* C = reinterpret_cast<__half2 const*>(&c);

        __half2 Dlo = __hfma2(A, Blo, C[0]);
        __half2 Dhi = __hfma2(A, Bhi, C[1]);

        Array<half_t, 2>* D = reinterpret_cast<Array<half_t, 2>*>(&d);

        D[0] = reinterpret_cast<Array<half_t, 2> const&>(Dlo);
        D[1] = reinterpret_cast<Array<half_t, 2> const&>(Dhi);

#else
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < 2; ++j) {
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < 2; ++i) {
                d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j];
            }
        }
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation
template <>
struct Mma<gemm::GemmShape<2, 2, 1>, 1, half_t, layout::ColumnMajor, half_t,
           layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd> {
    using Shape = gemm::GemmShape<2, 2, 1>;
    using Operator = OpMultiplyAdd;

    CUTLASS_HOST_DEVICE
    void operator()(Array<half_t, 4>& d, Array<half_t, 2> const& a,
                    Array<half_t, 2> const& b, Array<half_t, 4> const& c) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

        __half2 Alo = __low2half2(reinterpret_cast<__half2 const&>(a));
        __half2 Ahi = __high2half2(reinterpret_cast<__half2 const&>(a));
        __half2 const& B = reinterpret_cast<__half2 const&>(b);

        __half2 const* C = reinterpret_cast<__half2 const*>(&c);

        __half2 Dlo = __hfma2(Alo, B, C[0]);
        __half2 Dhi = __hfma2(Ahi, B, C[0]);

        Array<half_t, 2>* D = reinterpret_cast<Array<half_t, 2>*>(&d);

        D[0] = reinterpret_cast<Array<half_t, 2>&>(Dlo);
        D[1] = reinterpret_cast<Array<half_t, 2>&>(Dhi);
#else
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < 2; ++i) {
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < 2; ++j) {
                d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j];
            }
        }
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace arch
}  // namespace cutlass
